#!/usr/bin/env python
# coding: utf-8

# In[34]:


from turtle import st
import pandas as pd
import numpy as np
from scipy.signal import savgol_filter
import matplotlib as mpl
from scipy import interpolate
import statsmodels.api as sm

# Default y-axis tick label size (15) for every figure created below.
mpl.rc('ytick', labelsize=15)

def get_ts_data(TSfile, variable):  # Function to read tslist
    """Return one column of a WRF tslist (``.TS``) time-series file.

    Parameters
    ----------
    TSfile : str or file-like
        Path (or anything ``np.genfromtxt`` accepts) of the tslist file.
        The first line is treated as a header and skipped.
    variable : str
        Name of the column to return; must be one of the names in
        ``col_names`` below.  The legacy spelling ``'tsbl'`` is accepted as
        an alias for ``'tslb'``.

    Returns
    -------
    numpy.ndarray
        The values of the requested column, one entry per data row.

    Raises
    ------
    ValueError
        If ``variable`` is not a known column name.  The check happens
        before the file is opened, so a typo fails fast instead of after
        parsing the whole file.
    """
    # column names as defined by the WRFV3/run/README.tslist
    col_names = ['id', 'ts_hour', 'id_tsloc', 'ix', 'iy', 't', 'q', 'u', 'v', 'psfc',
                 'glw', 'gsw', 'hfx', 'lh', 'tsk', 'tslb', 'rainc', 'rainnc', 'clw']

    # Earlier versions of this function spelled the soil-temperature column
    # 'tsbl'; keep accepting it so existing callers do not break.
    if variable == 'tsbl':
        variable = 'tslb'

    # check that the input variable matches with one in the list
    if variable not in col_names:
        raise ValueError(
            "That variable is not available. Choose a variable from the following list\n"
            "        'id'           grid ID\n"
            "        'ts_hour':     forecast time in hours\n"
            "        'id_tsloc':    time series ID\n"
            "        'ix':          grid location (nearest grid to the station)\n"
            "        'iy':          grid location (nearest grid to the station)\n"
            "        't':           2 m Temperature (K)\n"
            "        'q':           2 m vapor mixing ratio (kg/kg)\n"
            "        'u':           10 m U wind (earth-relative)\n"
            "        'v':           10 m V wind (earth-relative)\n"
            "        'psfc':        surface pressure (Pa)\n"
            "        'glw':         downward longwave radiation flux at the ground (W/m^2, downward is positive)\n"
            "        'gsw':         net shortwave radiation flux at the ground (W/m^2, downward is positive)\n"
            "        'hfx':         surface sensible heat flux (W/m^2, upward is positive)\n"
            "        'lh':          surface latent heat flux (W/m^2, upward is positive)\n"
            "        'tsk':         skin temperature (K)\n"
            "        'tslb':        top soil layer temperature (K)\n"
            "        'rainc':       rainfall from a cumulus scheme (mm)\n"
            "        'rainnc':      rainfall from an explicit scheme (mm)\n"
            "        'clw':         total column-integrated water vapor and cloud variables\n\n"
        )

    # load the file into a structured numpy array (one named field per column)
    TS = np.genfromtxt(TSfile, skip_header=1, names=col_names)

    return TS[variable]

# Names of the RAWS stations as used in the .csv files and in the tslist (.TS) files, respectively (building these lists with glob.glob would be more robust, but this was written quickly :))
##stations = ['ATC10', 'AV313', 'BLD01', 'C7944', 'CO109', 'E3608', 'E5937', 'E6155', 'E9298', 'E9688', 'F0288', 'F0869', 'F2847', 'G0194', 'UP709']
##wrf_stations = ['ATC', 'AV3', 'BLD', 'C79', 'CO1', 'E36', 'E59', 'E61', 'E92', 'E96', 'F02', 'F08', 'F28', 'G01', 'UP7']
#stations = ['ATC10', 'AV313', 'BLD01', 'CO109', 'E6155', 'E9298', 'E9688', 'G0194']
#wrf_stations = ['ATC', 'AV3', 'BLD', 'CO1', 'E61', 'E92', 'E96', 'G01']
#c = ['r','r','orange','orange','b','b','lightskyblue','lightskyblue']
#mark = ['o','^','o','^','o','^','o','^']

import os

import matplotlib.pyplot as plt

# Observation (RAWS) station IDs and the matching WRF tslist prefixes.
# The four lists are index-aligned: entry i of each describes the same station.
stations = ['CO109', 'E9688', 'BLD01', 'G0194', 'E6155', 'ATC10', 'AV313', 'E9298']
wrf_stations = ['CO1', 'E96', 'BLD', 'G01', 'E61', 'ATC', 'AV3', 'E92']
c = ['orange','lightskyblue','orange','lightskyblue','blue','red','red','blue']  # marker face colour
mark = ['^','o','o','^','o','o','^','^']  # marker shape
dom = 'd01'  # WRF domain tag used in the tslist file names

if not (len(stations) == len(wrf_stations) == len(c) == len(mark)):
    raise ValueError('stations, wrf_stations, c and mark must all have the same length')

# Columns kept from each observation CSV.
OBS_COLUMNS = ['Date_Time', 'wind_speed_set_1', 'wind_direction_set_1', 'wind_gust_set_1']
# 0.44704 is the mph -> m/s factor; this assumes the CSV speeds are in mph.
# TODO confirm against the units row of the CSV files.
MPH_TO_MS = 0.44704
# Shared x ticks: hours since the start of the comparison window.
HOUR_TICKS = [0, 3, 6, 9, 12, 15, 18]
FIG_DIR = 'figs'


def _load_obs_window(csv_path, keep_hour, hour_offset):
    """Read one day's station CSV and keep the rows selected by hour of day.

    ``keep_hour`` receives the hour-of-day column and returns a boolean row
    mask.  The returned frame has a ``since`` column equal to
    ``hour + minute/60 + hour_offset``.
    """
    # The CSV has a two-row header (rows 6 and 7), hence MultiIndex columns.
    obs = pd.read_csv(csv_path, header=[6, 7])
    # .copy() so the column assignments below act on an independent frame
    # rather than on a selection of the frame returned by read_csv.
    obs = obs[OBS_COLUMNS].copy()
    # stack()/unstack() turns the one-column sub-frame into a Series for the
    # datetime parsing / .dt accessors and back into a frame for assignment.
    obs['Date_Time'] = pd.to_datetime(obs['Date_Time'].stack(), format='%m/%d/%Y %H:%M %Z').unstack()
    obs['time'] = obs['Date_Time'].stack().dt.strftime('%H').astype(float).unstack()
    obs = obs[keep_hour(obs['time'])].copy()

    stamps = obs['Date_Time'].stack()
    hours = stamps.dt.strftime('%H').astype(float).unstack()
    minutes = stamps.dt.strftime('%M').astype(float).unstack()
    obs['since'] = (hours + minutes / 60) + hour_offset
    return obs


def _format_time_axis(is_last_row):
    """Apply the shared x ticks; only the bottom row gets tick labels and an axis label."""
    if is_last_row:
        plt.xticks(HOUR_TICKS, fontsize=15)
        plt.xlabel('Time Since 12/30 1800 UTC (hrs)', fontsize=18)
    else:
        plt.xticks(HOUR_TICKS, [''] * len(HOUR_TICKS))


n_rows = len(stations)
label_row = n_rows // 2  # row that carries the shared y-axis captions
fig1 = plt.figure(figsize=(12,12))

# Loop through all the stations: one row per station, speed on the left,
# direction on the right.
for ii, (st_id, wrf_id, face_colour, marker) in enumerate(zip(stations, wrf_stations, c, mark)):

    print (st_id)

    csv_file1 = '123021/{}.csv'.format(st_id) # location of csv
    csv_file2 = '123121/{}.csv'.format(st_id)

    wrf_file = '/glade/scratch/tjuliano/people/for_kasra/spotting/' + wrf_id + '.' + dom + '.TS' # location of tslist
    print (wrf_file)
    # NOTE: the matching run under .../no_spotting/ used to be read here as
    # well, but its winds were never plotted, so it is no longer loaded.

    # First day: hours >= 18, shifted so that 18:00 maps to 0.
    weather_station1 = _load_obs_window(csv_file1, lambda hour: hour >= 18, -18)
    # Second day: hours <= 12, shifted by +6 so it continues where day one ends.
    weather_station2 = _load_obs_window(csv_file2, lambda hour: hour <= 12, 6)
    weather_stations = pd.concat([weather_station1, weather_station2], ignore_index=True)

    # Taking speed and direction and convert speed.  Multiply into a new array
    # instead of in place: to_numpy() may hand back a read-only view of the
    # frame's data, and in-place scaling also fails for integer columns.
    obs_speed = weather_stations[['wind_speed_set_1']].to_numpy() * MPH_TO_MS
    obs_gust = weather_stations[['wind_gust_set_1']].to_numpy() * MPH_TO_MS
    obs_dir = weather_stations[['wind_direction_set_1']].to_numpy()

    # Time
    time_sim = weather_stations['since'].to_numpy().flatten()

    # Read tslist
    wrf_time = get_ts_data(wrf_file, 'ts_hour')
    u = get_ts_data(wrf_file, 'u')
    v = get_ts_data(wrf_file, 'v')

    # Convert WRF U and V to speed and direction (direction in degrees,
    # wrapped into [0, 360))
    wrf_speed = np.sqrt(np.power(u, 2) + np.power(v, 2))
    wrf_dir = (270 - np.rad2deg(np.arctan2(v, u))) % 360

    # Smooth the model speed with a Savitzky-Golay filter
    # (601-sample window, cubic polynomial).
    wrf_speed_yhat = savgol_filter(wrf_speed, 601, 3)

    is_last_row = ii == n_rows - 1

    ## Plotting :)
    # Left panel: wind speed and gusts
    plt.subplot(n_rows, 2, 2*ii + 1)
    plt.plot(wrf_time, wrf_speed_yhat, 'g', label='WRF Wind Speed', zorder=1)
    # Speed and gust share colour and shape; they differ only in edge colour.
    plt.scatter(time_sim, obs_speed, marker=marker, s=50, fc=face_colour, ec='gray', zorder=2)
    plt.scatter(time_sim, obs_gust, marker=marker, s=50, fc=face_colour, ec='k', zorder=2)
    _format_time_axis(is_last_row)
    plt.yticks(fontsize=15)
    plt.tick_params(axis='y', which='both', labelleft=True, labelright=False)
    plt.grid()

    if ii == label_row:
        # Shared y caption for the whole left column, placed by hand in data coordinates.
        plt.text(-5, -5, 'Wind Speed (ms$^{-1}$)', rotation=90, fontsize=20)

    # Right panel: wind direction
    plt.subplot(n_rows, 2, 2*ii + 2)
    # Plot only every 500th model sample.
    plt.scatter(wrf_time[::500], wrf_dir[::500], marker='+', s=50, c='g', label='WRF Spotting Wind Direction', zorder=1)
    plt.scatter(time_sim, obs_dir, marker=marker, s=50, fc=face_colour, ec='gray', label='Observed Wind Direction', zorder=2)
    plt.ylim(-5, 365)
    _format_time_axis(is_last_row)
    plt.yticks([0, 90, 180, 270, 360], fontsize=15)
    plt.tick_params(axis='y', which='both', labelleft=False, labelright=True)
    plt.grid()

    if ii == label_row:
        # Shared y caption for the whole right column, placed by hand in data coordinates.
        plt.text(23.5, -45, 'Wind Direction ('+u"\N{DEGREE SIGN}"+')', rotation=270, fontsize=20)

    # Station tag, drawn in the direction panel's data coordinates.
    plt.text(-4.5, 150, wrf_id, fontsize=20)

plt.tight_layout()
# Create the output directory if needed so savefig cannot fail on a fresh checkout.
os.makedirs(FIG_DIR, exist_ok=True)
plt.savefig(FIG_DIR + '/compare_obs_wrf_' + dom + '_panel.png', dpi=600, bbox_inches='tight')
plt.close()




